{
 "metadata": {},
 "nbformat": 3,
 "nbformat_minor": 0,
 "worksheets": [
  {
   "cells": [
    {
     "cell_type": "markdown",
     "metadata": {},
     "source": [
      "Multidot\n",
      "=========\n",
      "\n",
      "The matrix multiplication function `numpy.dot()` takes only two\n",
      "arguments, so multiplying more than two arrays together requires\n",
      "nested function calls, which are hard to read:"
     ]
    },
    {
     "cell_type": "code",
     "collapsed": false,
     "input": [
      "from numpy import dot  # a, b, c and d are assumed to be existing 2-D arrays\n",
      "dot(dot(dot(a, b), c), d)"
     ],
     "language": "python",
     "metadata": {},
     "outputs": []
    },
    {
     "cell_type": "markdown",
     "metadata": {},
     "source": [
      "versus infix notation, where you could simply write"
     ]
    },
    {
     "cell_type": "code",
     "collapsed": false,
     "input": [
      "a*b*c*d  # illustrative only: for NumPy arrays, * multiplies elementwise"
     ],
     "language": "python",
     "metadata": {},
     "outputs": []
    },
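    {
     "cell_type": "markdown",
     "metadata": {},
     "source": [
      "On Python 3.5 or later with NumPy 1.10 or later (an assumption about your\n",
      "environment, not something the recipes on this page require), the `@`\n",
      "operator gives this infix form for matrix products. A minimal sketch with\n",
      "three made-up 2x2 arrays:"
     ]
    },
    {
     "cell_type": "code",
     "collapsed": false,
     "input": [
      "import numpy\n",
      "\n",
      "# Made-up example operands; any conformable 2-D arrays would do.\n",
      "a = numpy.array([[1, 2], [3, 4]])\n",
      "b = numpy.array([[0, 1], [1, 0]])\n",
      "c = numpy.array([[2, 0], [0, 2]])\n",
      "\n",
      "# @ is left-associative, so this is evaluated as (a @ b) @ c.\n",
      "infix = a @ b @ c\n",
      "nested = numpy.dot(numpy.dot(a, b), c)\n",
      "\n",
      "print(numpy.array_equal(infix, nested))"
     ],
     "language": "python",
     "metadata": {},
     "outputs": []
    },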
    {
     "cell_type": "markdown",
     "metadata": {},
     "source": [
      "There are a couple of ways to define an `mdot` function that acts like\n",
      "`dot` but accepts more than two arguments. Using one of these allows you\n",
      "to write the above expression as"
     ]
    },
    {
     "cell_type": "code",
     "collapsed": false,
     "input": [
      "mdot(a,b,c,d)"
     ],
     "language": "python",
     "metadata": {},
     "outputs": []
    },
    {
     "cell_type": "markdown",
     "metadata": {},
     "source": [
      "Using reduce\n",
      "------------\n",
      "\n",
      "The simplest way is to use `reduce`:"
     ]
    },
    {
     "cell_type": "code",
     "collapsed": false,
     "input": [
      "from functools import reduce  # reduce is not a builtin in Python 3\n",
      "import numpy\n",
      "def mdot(*args):\n",
      "    return reduce(numpy.dot, args)"
     ],
     "language": "python",
     "metadata": {},
     "outputs": []
    },
    {
     "cell_type": "markdown",
     "metadata": {},
     "source": [
      "Or use the equivalent loop (which is apparently the preferred style [for\n",
      "Py3K](http://www.python.org/dev/peps/pep-3100/#id53), where `reduce`\n",
      "moved from the builtins into `functools`):"
     ]
    },
    {
     "cell_type": "code",
     "collapsed": false,
     "input": [
      "import numpy\n",
      "\n",
      "def mdot(*args):\n",
      "    ret = args[0]\n",
      "    for a in args[1:]:\n",
      "        ret = numpy.dot(ret, a)\n",
      "    return ret"
     ],
     "language": "python",
     "metadata": {},
     "outputs": []
    },
    {
     "cell_type": "markdown",
     "metadata": {},
     "source": [
      "This always gives left-to-right associativity, i.e. the\n",
      "expression is interpreted as `(((a*b)*c)*d)`.\n",
      "\n",
      "You can also make a right-associative version of the loop:"
     ]
    },
    {
     "cell_type": "code",
     "collapsed": false,
     "input": [
      "import numpy\n",
      "\n",
      "def mdotr(*args):\n",
      "    ret = args[-1]\n",
      "    for a in reversed(args[:-1]):\n",
      "        ret = numpy.dot(a, ret)\n",
      "    return ret"
     ],
     "language": "python",
     "metadata": {},
     "outputs": []
    },
    {
     "cell_type": "markdown",
     "metadata": {},
     "source": [
      "which evaluates as `(a*(b*(c*d)))`. But sometimes you'd like to\n",
      "have finer control, since the order in which the matrix products are\n",
      "performed can have a big impact on performance: the mathematical result\n",
      "is the same, but the sizes of the intermediate arrays depend on the\n",
      "grouping. The next version gives that control.\n",
      "\n",
      "Controlling order of evaluation\n",
      "-------------------------------\n",
      "\n",
      "If we're willing to sacrifice NumPy's ability to treat tuples as arrays,\n",
      "we can use tuples as grouping constructs. This version of `mdot`\n",
      "allows syntax like this:"
     ]
    },
    {
     "cell_type": "code",
     "collapsed": false,
     "input": [
      "mdot(a,((b,c),d))"
     ],
     "language": "python",
     "metadata": {},
     "outputs": []
    },
    {
     "cell_type": "markdown",
     "metadata": {},
     "source": [
      "to control the order in which the pairwise `dot` calls are made."
     ]
    },
    {
     "cell_type": "code",
     "collapsed": false,
     "input": [
      "import numpy\n",
      "\n",
      "def mdot(*args):\n",
      "    \"\"\"Multiply all the arguments using matrix product rules.\n",
      "    The output is equivalent to multiplying the arguments one by one\n",
      "    from left to right using dot().\n",
      "    Precedence can be controlled by creating tuples of arguments,\n",
      "    for instance mdot(a,((b,c),d)) computes a*((b*c)*d).\n",
      "    Note that this means the output of dot(a,b) and mdot(a,b) will differ if\n",
      "    a or b is a pure tuple of numbers.\n",
      "    \"\"\"\n",
      "    if len(args) == 1:\n",
      "        return args[0]\n",
      "    elif len(args) == 2:\n",
      "        return _mdot_r(args[0], args[1])\n",
      "    else:\n",
      "        # Fold everything but the last argument first (left to right).\n",
      "        return _mdot_r(args[:-1], args[-1])\n",
      "\n",
      "def _mdot_r(a, b):\n",
      "    \"\"\"Recursive helper for mdot\"\"\"\n",
      "    # isinstance() replaces the Python 2-only types.TupleType check.\n",
      "    if isinstance(a, tuple):\n",
      "        if len(a) > 1:\n",
      "            a = mdot(*a)\n",
      "        else:\n",
      "            a = a[0]\n",
      "    if isinstance(b, tuple):\n",
      "        if len(b) > 1:\n",
      "            b = mdot(*b)\n",
      "        else:\n",
      "            b = b[0]\n",
      "    return numpy.dot(a, b)"
     ],
     "language": "python",
     "metadata": {},
     "outputs": []
    },
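    {
     "cell_type": "markdown",
     "metadata": {},
     "source": [
      "To make the performance point concrete, here is a rough sketch that counts\n",
      "scalar multiplications for the two groupings of a three-matrix product.\n",
      "The shapes and the `dot_cost` helper are made up for illustration, and the\n",
      "count assumes the textbook algorithm, which is not necessarily what the\n",
      "BLAS library behind NumPy does. NumPy 1.10 and later also provide\n",
      "`numpy.linalg.multi_dot`, which chooses a grouping for you."
     ]
    },
    {
     "cell_type": "code",
     "collapsed": false,
     "input": [
      "import numpy\n",
      "\n",
      "# Made-up shapes, chosen so that the grouping matters.\n",
      "a = numpy.ones((1000, 2))\n",
      "b = numpy.ones((2, 1000))\n",
      "c = numpy.ones((1000, 1))\n",
      "\n",
      "def dot_cost(shape_x, shape_y):\n",
      "    \"\"\"Scalar multiplications for a textbook (n, k) x (k, m) product.\"\"\"\n",
      "    n, k = shape_x\n",
      "    m = shape_y[1]\n",
      "    return n * k * m\n",
      "\n",
      "# (a*b)*c builds a 1000x1000 intermediate array.\n",
      "left_cost = dot_cost(a.shape, b.shape) + dot_cost((1000, 1000), c.shape)\n",
      "# a*(b*c) only builds a 2x1 intermediate array.\n",
      "right_cost = dot_cost(b.shape, c.shape) + dot_cost(a.shape, (2, 1))\n",
      "\n",
      "left = numpy.dot(numpy.dot(a, b), c)\n",
      "right = numpy.dot(a, numpy.dot(b, c))\n",
      "auto = numpy.linalg.multi_dot([a, b, c])  # picks the grouping itself\n",
      "\n",
      "print(left_cost)\n",
      "print(right_cost)\n",
      "print(numpy.allclose(left, right) and numpy.allclose(left, auto))"
     ],
     "language": "python",
     "metadata": {},
     "outputs": []
    },
    {
     "cell_type": "markdown",
     "metadata": {},
     "source": [
      "With these shapes the left grouping needs 3,000,000 scalar multiplications\n",
      "and the right grouping 4,000, so `mdot(a,(b,c))` would be the cheaper call."
     ]
    },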
    {
     "cell_type": "markdown",
     "metadata": {},
     "source": [
      "Multiply\n",
      "--------\n",
      "\n",
      "Note that the elementwise multiplication function `numpy.multiply` has\n",
      "the same two-argument limitation as `numpy.dot`. The same\n",
      "generalized forms can be defined for `multiply`.\n",
      "\n",
      "Left-associative versions:"
     ]
    },
    {
     "cell_type": "code",
     "collapsed": false,
     "input": [
      "from functools import reduce  # reduce is not a builtin in Python 3\n",
      "import numpy\n",
      "def mmultiply(*args):\n",
      "    return reduce(numpy.multiply, args)"
     ],
     "language": "python",
     "metadata": {},
     "outputs": []
    },
    {
     "cell_type": "code",
     "collapsed": false,
     "input": [
      "import numpy\n",
      "\n",
      "def mmultiply(*args):\n",
      "    ret = args[0]\n",
      "    for a in args[1:]:\n",
      "        ret = numpy.multiply(ret, a)\n",
      "    return ret"
     ],
     "language": "python",
     "metadata": {},
     "outputs": []
    },
    {
     "cell_type": "markdown",
     "metadata": {},
     "source": [
      "Right-associative version:"
     ]
    },
    {
     "cell_type": "code",
     "collapsed": false,
     "input": [
      "import numpy\n",
      "\n",
      "def mmultiplyr(*args):\n",
      "    ret = args[-1]\n",
      "    for a in reversed(args[:-1]):\n",
      "        ret = numpy.multiply(a, ret)\n",
      "    return ret"
     ],
     "language": "python",
     "metadata": {},
     "outputs": []
    },
    {
     "cell_type": "markdown",
     "metadata": {},
     "source": [
      "Version using tuples to control order of evaluation:"
     ]
    },
    {
     "cell_type": "code",
     "collapsed": false,
     "input": [
      "import numpy\n",
      "\n",
      "def mmultiply(*args):\n",
      "    \"\"\"Multiply all the arguments using elementwise product.\n",
      "    The output is equivalent to multiplying the arguments one by one\n",
      "    from left to right using multiply().\n",
      "    Precedence can be controlled by creating tuples of arguments,\n",
      "    for instance mmultiply(a,((b,c),d)) computes a*((b*c)*d).\n",
      "    Note that this means the output of multiply(a,b) and mmultiply(a,b) will differ if\n",
      "    a or b is a pure tuple of numbers.\n",
      "    \"\"\"\n",
      "    if len(args) == 1:\n",
      "        return args[0]\n",
      "    elif len(args) == 2:\n",
      "        return _mmultiply_r(args[0], args[1])\n",
      "    else:\n",
      "        # Fold everything but the last argument first (left to right).\n",
      "        return _mmultiply_r(args[:-1], args[-1])\n",
      "\n",
      "def _mmultiply_r(a, b):\n",
      "    \"\"\"Recursive helper for mmultiply\"\"\"\n",
      "    # isinstance() replaces the Python 2-only types.TupleType check.\n",
      "    if isinstance(a, tuple):\n",
      "        if len(a) > 1:\n",
      "            a = mmultiply(*a)\n",
      "        else:\n",
      "            a = a[0]\n",
      "    if isinstance(b, tuple):\n",
      "        if len(b) > 1:\n",
      "            b = mmultiply(*b)\n",
      "        else:\n",
      "            b = b[0]\n",
      "    return numpy.multiply(a, b)"
     ],
     "language": "python",
     "metadata": {},
     "outputs": []
    }
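    ,
    {
     "cell_type": "markdown",
     "metadata": {},
     "source": [
      "As a usage sketch for the left-associative form (the arrays hold made-up\n",
      "values), the reduce-based product matches the chained `*` expression. For\n",
      "operands that all share one shape, stacking them and calling\n",
      "`numpy.multiply.reduce` along the first axis is another option; that\n",
      "alternative is an addition here, not one of the recipes above."
     ]
    },
    {
     "cell_type": "code",
     "collapsed": false,
     "input": [
      "from functools import reduce\n",
      "import numpy\n",
      "\n",
      "# Made-up example operands, all with the same shape.\n",
      "x = numpy.array([1.0, 2.0, 3.0])\n",
      "y = numpy.array([4.0, 5.0, 6.0])\n",
      "z = numpy.array([2.0, 2.0, 2.0])\n",
      "\n",
      "chained = x * y * z                          # elementwise infix form\n",
      "reduced = reduce(numpy.multiply, (x, y, z))  # same fold as the reduce-based mmultiply\n",
      "stacked = numpy.multiply.reduce(numpy.array([x, y, z]), axis=0)\n",
      "\n",
      "print(chained)\n",
      "print(numpy.array_equal(chained, reduced) and numpy.array_equal(chained, stacked))"
     ],
     "language": "python",
     "metadata": {},
     "outputs": []
    }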
   ],
   "metadata": {}
  }
 ]
}